# Owner(s): ["oncall: distributed"]

import contextlib
import io
from copy import deepcopy
from collections import OrderedDict
from itertools import product
import functools

import torch
from torch import nn
from torch.cuda.amp import autocast
import torch.nn.parallel as dp
from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, onlyCUDA, skipMeta
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_utils import _assertGradAndGradgradChecks, gradcheck
from torch.testing._internal.common_utils import dtype2prec_DONTUSE
from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle_if
import torch.nn.functional as F

NO_NCCL = not hasattr(torch.distributed, "ProcessGroupNCCL")

# batched grad doesn't support data parallel
gradcheck = functools.partial(gradcheck, check_batched_grad=False)
_assertGradAndGradgradChecks = functools.partial(_assertGradAndGradgradChecks, check_batched_grad=False)

class TestDataParallel(TestCase):
    """Tests for ``nn.DataParallel`` and the ``torch.nn.parallel`` primitives.

    Covers ``data_parallel``, ``parallel_apply``, ``scatter``, ``gather`` and
    ``replicate``. Most tests are skipped unless ``TEST_MULTIGPU`` is set.
    """

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_buffers_requiring_grad(self):
        """Gradients flow back through a registered buffer that requires grad."""
        class TestModule(nn.Module):
            def __init__(self, t):
                super().__init__()
                self.register_buffer('t_rg', t)
                self.register_buffer('t_not_rg', t.clone().detach())

            def forward(self, x):
                return x * self.t_rg + self.t_not_rg

        m = TestModule(torch.randn(100, device='cuda', requires_grad=True, dtype=torch.double))
        self.assertTrue(m.t_rg.requires_grad)

        dpm = nn.DataParallel(m, [0, 1])
        inp = torch.randn(2, 100, device='cuda', dtype=torch.double)

        # ``t`` is deliberately unused: the gradcheck input is ``m.t_rg`` itself,
        # so (presumably) its perturbations reach ``dpm`` through the buffer.
        def fn(t):
            return dpm(inp)

        gradcheck(fn, (m.t_rg,))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_rnn(self):
        """One SGD step on an LSTM gives the same parameters with and without DataParallel."""

        class TestModule(torch.nn.Module):

            def __init__(self):
                super().__init__()
                self.rnn = torch.nn.LSTM(300, 1024, 1, batch_first=True, bidirectional=True)

            def forward(self, x):
                self.rnn.flatten_parameters()
                return self.rnn(x)

        def step(model):
            opt = torch.optim.SGD(model.parameters(), lr=10)
            input = torch.ones(4, 4, 300).to(0)
            output = model(input)
            loss = F.mse_loss(output[0], torch.zeros_like(output[0]))
            loss.backward()
            opt.step()

        with torch.no_grad():
            model = TestModule().to(0)
            model_dp = torch.nn.DataParallel(deepcopy(model))

            # make sure DP does not crash when grad is disabled.
            # See #21108
            model_dp(torch.rand(2, 4, 300).to(0))

        step(model)
        step(model_dp)

        for p1, p2 in zip(model.parameters(), model_dp.parameters()):
            self.assertTrue(p1.allclose(p2))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_lazy_linear(self):
        """Running DataParallel on an uninitialized LazyLinear raises ValueError."""

        with self.assertRaisesRegex(ValueError, 'Attempted to use an uninitialized parameter'):
            model_dp = torch.nn.DataParallel(torch.nn.LazyLinear(10).to(0))
            model_dp(torch.rand(10, 10).to(0))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_parallel_apply(self):
        """parallel_apply matches calling each module on its own input."""
        l1 = nn.Linear(10, 5).to("cuda:0", torch.float)
        l2 = nn.Linear(10, 5).to("cuda:1", torch.float)
        i1 = torch.randn(2, 10, device="cuda:0", dtype=torch.float)
        i2 = torch.randn(2, 10, device="cuda:1", dtype=torch.float)
        expected1 = l1(i1)
        expected2 = l2(i2)
        modules = (l1, l2)
        expected_outputs = (expected1, expected2)

        # each input can be either a collection of positional arguments
        #                       or an object representing the single argument
        for inputs in [((i1,), (i2,)), (i1, i2)]:
            outputs = dp.parallel_apply(modules, inputs, None)
            for out, expected in zip(outputs, expected_outputs):
                self.assertEqual(out, expected)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_parallel_apply_autocast(self):
        """parallel_apply under autocast matches direct calls under autocast."""
        l1 = nn.Linear(10, 5).to("cuda:0", torch.float)
        l2 = nn.Linear(10, 5).to("cuda:1", torch.float)
        i1 = torch.randn(2, 10, device="cuda:0", dtype=torch.float)
        i2 = torch.randn(2, 10, device="cuda:1", dtype=torch.float)
        with autocast():
            expected1 = l1(i1)
            expected2 = l2(i2)
        modules = (l1, l2)
        expected_outputs = (expected1, expected2)

        # each input can be either a collection of positional arguments
        #                       or an object representing the single argument
        for inputs in [((i1,), (i2,)), (i1, i2)]:
            with autocast():
                outputs = dp.parallel_apply(modules, inputs, None)
            for out, expected in zip(outputs, expected_outputs):
                self.assertEqual(out, expected)

    @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "CUDA unavailable")
    def test_parallel_apply_passes_exception(self):
        """An exception raised inside a replica is re-raised with replica/device context."""
        # we define and instantiate a module that will throw a KeyError
        class TestModule(nn.Module):

            def forward(self, *args):
                return {}['wonderful']

        l1 = TestModule().to("cuda", torch.float)
        # and check that parallel_apply passes on the exception
        # (we can use a single device twice for this test)
        with self.assertRaisesRegex(KeyError,
                                    'Caught KeyError in replica \\d '
                                    'on device 0.\nOriginal Traceback'
                                    '[\\s\\S]+wonderful'):
            dp.parallel_apply(modules=(l1, l1), inputs=(None, None))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_multiple_input(self):
        """Tensor, float and keyword inputs give the same outputs and grads as a direct call."""
        class TestModule(nn.Module):

            def forward(self, var1, var2, float1, var3=None):
                if var3 is None:
                    return float1 * (var1 * var2)
                else:
                    return float1 * (var1 * var2 + var3)

        m = TestModule()
        var1 = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
        var2 = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
        var3 = torch.randn(5, 5, dtype=torch.float, requires_grad=False)

        float1 = torch.randn(1).item()

        expected = m(var1, var2, float1)
        loss = expected.sum()
        loss.backward()
        gvar1_exp = var1.grad.clone()
        gvar2_exp = var2.grad.clone()

        def local_test(out):
            # Reset grads so each data-parallel run is compared in isolation.
            with torch.no_grad():
                var1.grad.fill_(0.0)
                var2.grad.fill_(0.0)
            loss = out.sum()
            loss.backward()
            self.assertEqual(out, expected)
            self.assertEqual(gvar1_exp, var1.grad)
            self.assertEqual(gvar2_exp, var2.grad)

        out = dp.data_parallel(m, (var1, var2, float1), (0, 1))
        local_test(out)

        out = dp.data_parallel(m, (var1, var2, float1), (1, 0))
        local_test(out)

        out = dp.data_parallel(m, (var1, var2, float1), (0,))
        local_test(out)

        with torch.no_grad():
            var1.grad.fill_(0.0)
            var2.grad.fill_(0.0)
        expected = m(var1, var2, float1, var3=var3)
        loss = expected.sum()
        loss.backward()
        gvar1_exp = var1.grad.clone()
        gvar2_exp = var2.grad.clone()

        dpm = nn.DataParallel(TestModule())
        out = dpm(var1, var2, float1, var3=var3)
        local_test(out)

        dpm = nn.DataParallel(TestModule(), device_ids=[0])
        out = dpm(var1, var2, float1, var3=var3)
        local_test(out)

        kwarg_wrap = {'var3': var3}
        out = dp.data_parallel(
            m, (var1, var2, float1), (0, 1), module_kwargs=kwarg_wrap)
        local_test(out)

        out = dp.data_parallel(
            m, (var1, var2, float1), (0,), module_kwargs=kwarg_wrap)
        local_test(out)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_small_back(self):
        """data_parallel on a small Linear matches the direct call."""
        l = nn.Linear(10, 5).float().cuda()
        i = torch.randn(20, 10, dtype=torch.float, device="cuda")
        out = dp.data_parallel(l, i, (0, 1))
        self.assertEqual(out, l(i))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_model_device(self):
        """At forward time, the module's parameters and buffers must be on ``device_ids[0]``.

        Each case is checked through both ``nn.DataParallel`` and the functional
        ``nn.parallel.data_parallel``.
        """
        l = nn.Linear(2, 2)
        inp = torch.randn(2, 2)
        inp_cuda0 = inp.cuda(0)
        inp_cuda1 = inp.cuda(1)

        error_msg = "module must have its parameters and buffers on device {}"

        @contextlib.contextmanager
        def dummy_ctx_manager():
            yield

        def test(inner_m, dp_device, inp, device_ids, should_fail):
            if device_ids is None:
                device_ids = list(range(torch.cuda.device_count()))

            if isinstance(device_ids[0], torch.device):
                expect_device = device_ids[0]
            else:
                expect_device = torch.device(f"cuda:{device_ids[0]}")

            if should_fail:
                def assert_correct():
                    return self.assertRaisesRegex(RuntimeError, error_msg.format(expect_device))
            else:
                assert_correct = dummy_ctx_manager

            # test DataParallel module
            dpm = nn.DataParallel(inner_m, device_ids)
            if dp_device is not None:
                dpm = dpm.to(dp_device)

            with assert_correct():
                dpm(inp)

            # test functional
            with assert_correct():
                nn.parallel.data_parallel(inner_m.to(dp_device), inp, device_ids)

        test(l.to('cpu'), None, inp, None, should_fail=True)
        test(l.cuda(1), None, inp_cuda0, None, should_fail=True)
        test(l.cuda(), None, inp_cuda0, [1, 0], should_fail=True)

        test(l.cuda(), None, inp_cuda0, None, should_fail=False)
        test(l.cpu(), 'cuda', inp_cuda0, None, should_fail=False)
        test(l.cuda(1), None, inp_cuda1, [1, 0], should_fail=False)
        test(l.cpu(), 'cuda:1', inp_cuda1, [1, 0], should_fail=False)

        s = nn.Sequential(l.cpu())
        test(s, None, inp, None, should_fail=True)
        test(s, None, inp, [0, 1], should_fail=True)
        test(s, None, inp, [1, 0], should_fail=True)

        s = nn.Sequential(deepcopy(l).cpu(), l.cuda())
        test(s, None, inp, None, should_fail=True)
        test(s, None, inp, [0, 1], should_fail=True)
        test(s, None, inp, [1, 0], should_fail=True)

        s = nn.Sequential(l.cuda(), deepcopy(l).cuda(1))
        test(s, None, inp, None, should_fail=True)
        test(s, None, inp, [0, 1], should_fail=True)
        test(s, None, inp, [1, 0], should_fail=True)

        s = nn.Sequential(l.cuda(), deepcopy(l).cuda())
        test(s, None, inp, None, should_fail=False)
        test(s, None, inp, [0, 1], should_fail=False)
        test(s, None, inp, [1, 0], should_fail=True)
        test(s.cpu(), None, inp, [1, 0], should_fail=True)
        test(s.cuda(1), None, inp, [1, 0], should_fail=False)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_model_no_refcycles(self):
        """A DataParallel forward pass leaves no objects for the cycle collector."""
        # A forward pass should not create reference cycles. Any cycle that
        # gc.collect() finds points at a cycle on the PyTorch side (or in the
        # module defined below).
        import gc

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(1, 1)

            def forward(self, x):
                return self.linear(x)

        gc.collect()
        model = nn.DataParallel(Model().cuda())
        data = torch.randn(1, device="cuda")
        model(data)

        refcycles = gc.collect()
        self.assertEqual(refcycles, 0)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_no_grad(self):
        """Replicas see the caller's grad mode: disabled under no_grad, enabled otherwise."""
        test = self

        class Layer(nn.Module):
            def forward(self, x):
                test.assertFalse(torch.is_grad_enabled())
                return x

        l = Layer()
        i = torch.randn(20, 10, dtype=torch.float, device="cuda")
        with torch.no_grad():
            dp.data_parallel(l, i, (0, 1))
        self.assertRaises(AssertionError, lambda: dp.data_parallel(l, i, (0, 1)))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel(self):
        """Outputs, output device and grads match a single-device run for both device orders."""
        l = nn.Linear(10, 5).float().cuda()
        i = torch.randn(20, 10, dtype=torch.float, device="cuda:1")
        l.cuda(1)
        expected_out = l(i)
        loss = expected_out.sum()
        loss.backward()
        expected_grads = []
        for param in l.parameters():
            expected_grads.append(param.grad.clone())
        dev_ids_list = [(0, 1), (1, 0)]
        for dev_id in dev_ids_list:
            with torch.cuda.device(dev_id[0]):
                l.cuda()
                l.zero_grad()
                out = dp.data_parallel(l, i, dev_id)
                loss = out.sum()
                loss.backward()
                self.assertEqual(out.get_device(), dev_id[0])
                self.assertEqual(out, expected_out)
                for expected, param in zip(expected_grads, l.parameters()):
                    self.assertEqual(param.grad, expected)

        # Check for None device_ids
        l = l.cuda()
        out = dp.data_parallel(l, i)
        self.assertEqual(out, expected_out)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_sparse(self):
        """Same as test_data_parallel, but for an Embedding with sparse gradients."""
        l = nn.Embedding(10, 5, sparse=True).to("cuda:1")
        i = torch.randint(10, (20, 5), device="cuda:1", dtype=torch.long)
        expected_out = l(i)
        loss = expected_out.sum()
        loss.backward()
        expected_grads = []
        for param in l.parameters():
            expected_grads.append(param.grad.clone())
        dev_ids_list = [(0, 1), (1, 0)]
        for dev_id in dev_ids_list:
            with torch.cuda.device(dev_id[0]):
                l.cuda()
                l.zero_grad()
                out = dp.data_parallel(l, i, dev_id)
                loss = out.sum()
                loss.backward()
                self.assertEqual(out.get_device(), dev_id[0])
                self.assertEqual(out, expected_out)
                for expected, param in zip(expected_grads, l.parameters()):
                    # Sparse grads may hold duplicate indices; coalesce before comparing.
                    self.assertEqual(param.grad.coalesce(), expected.coalesce())

        # Check for None device_ids
        l = l.cuda()
        out = dp.data_parallel(l, i)
        self.assertEqual(out, expected_out)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_nested_output(self):
        """Nested list/tuple/dict outputs are gathered with their container types preserved."""
        def fn(input):
            return [
                input, (input.sin(), input.cos(), [input.add(1)]), input,
                OrderedDict(a=input, b=[input.sin()])
            ]

        class Net(nn.Module):
            def forward(self, input):
                return fn(input)

        i = torch.randn(2, 2).float().cuda(1)
        gpus = range(torch.cuda.device_count())
        output = dp.data_parallel(Net(), i, gpus)
        self.assertEqual(output, fn(i))
        self.assertIsInstance(output[0], torch.Tensor)
        self.assertIsInstance(output[1], tuple)
        self.assertIsInstance(output[1][0], torch.Tensor)
        self.assertIsInstance(output[1][1], torch.Tensor)
        self.assertIsInstance(output[1][2], list)
        self.assertIsInstance(output[1][2][0], torch.Tensor)
        self.assertIsInstance(output[2], torch.Tensor)
        self.assertIsInstance(output[3], dict)
        self.assertEqual(len(output[3]), 2)
        self.assertIn('a', output[3])
        self.assertIn('b', output[3])
        self.assertIsInstance(output[3]['a'], torch.Tensor)
        self.assertIsInstance(output[3]['b'], list)
        self.assertIsInstance(output[3]['b'][0], torch.Tensor)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_nested_input(self):
        """Nested tuple inputs are scattered and reach forward intact."""
        def fn(input):
            return input[1][0]

        class Net(nn.Module):
            def forward(self, *input):
                return fn(input)

        i = torch.randn(20, 3, dtype=torch.float, device="cuda:1")
        input = (i.cos(), (i.sin(), i), i.sin())
        gpus = range(torch.cuda.device_count())
        output = dp.data_parallel(Net(), input, gpus)
        self.assertEqual(output, fn(input))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_module_zero_inputs(self):
        """Modules whose forward takes no inputs work with DataParallel and data_parallel."""
        class TestModule(nn.Module):
            def forward(self):
                t = torch.eye(2, 3, device='cuda:0')
                return t + (1 - t)

        def test_helper(output, expected):
            self.assertEqual(output.get_device(), 0)
            self.assertEqual(output, expected)

        expected = torch.ones(2, 3, device='cuda:0')
        model = TestModule()

        test_helper(nn.DataParallel(model, [0])(), expected)
        test_helper(nn.DataParallel(model, [0, 1])(), expected)
        test_helper(dp.data_parallel(model, None, [0]), expected)
        test_helper(dp.data_parallel(model, (), [0, 1]), expected)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_device_args(self):
        """``output_device`` and ``device_ids`` accept ``torch.device`` objects."""
        cuda0 = torch.device('cuda:0')
        cuda1 = torch.device('cuda:1')

        # test output_device
        l = nn.Linear(10, 5).to(cuda0, torch.float)
        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
        out = dp.data_parallel(l, i, device_ids=(0, 1), output_device=cuda0)
        self.assertEqual(out, l(i))

        # test device_ids
        l = nn.Linear(10, 5).to(cuda0, torch.float)
        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
        out = dp.data_parallel(l, i, device_ids=(cuda0, cuda1), output_device=cuda0)
        self.assertEqual(out, l(i))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_function_deletion(self):
        """Backward through a create_graph gradient of a DataParallel output."""
        # Regression test for #16532.
        def gradient_penalty(net, x):
            output = net(x)
            loss = torch.autograd.grad(
                outputs=output, inputs=x,
                grad_outputs=x.new_ones(output.size()),
                create_graph=True, retain_graph=True)[0].mean()
            return loss

        net = nn.Linear(4, 1).cuda()
        dpn = nn.DataParallel(net, [0, 1])
        x = torch.ones(2, 4, requires_grad=True).cuda()

        dpn.zero_grad()
        loss = gradient_penalty(dpn, x)
        loss.backward()
        grads = [p.grad for p in net.parameters()]
        self.assertEqual(2, len(grads))
        self.assertEqual(
            torch.tensor([[0.25, 0.25, 0.25, 0.25]], device='cuda:0'),
            grads[0])
        self.assertEqual(torch.tensor([0.0], device='cuda:0'), grads[1])

    def _test_scatter(self, tensor):
        """Scatter ``tensor`` over devices 0 and 1, then check values, devices and grads.

        Assumes ``tensor`` has 4 rows, so each device gets 2 (callers pass a 4x4 tensor).
        """
        x = tensor.detach().requires_grad_()
        result = dp.scatter(x, (0, 1))
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0], x[:2])
        self.assertEqual(result[0].get_device(), 0)
        self.assertEqual(result[1], x[2:])
        self.assertEqual(result[1].get_device(), 1)
        grad = result[0].detach().clone().fill_(2)
        result[0].backward(grad)
        self.assertEqual(x.grad[:2], grad)
        self.assertEqual(x.grad[2:], grad.clone().zero_())
        _assertGradAndGradgradChecks(self, lambda y: dp.scatter(y, (0, 1)), (x,))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_scatter_cpu(self):
        """Scatter a CPU tensor onto two GPUs."""
        self._test_scatter(torch.randn((4, 4), dtype=torch.double))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_scatter_gpu(self):
        """Scatter a GPU tensor onto two GPUs."""
        self._test_scatter(torch.randn((4, 4), dtype=torch.double).cuda())

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
    @skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
    def test_data_parallel_complex(self):
        """Complex parameters survive replication under DataParallel."""
        # We expect complex parameters to be broadcast by view_as_real, i.e. move from C to R^2
        class Cplx(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.cplx = torch.nn.Parameter(torch.zeros(1, 10, dtype=torch.cfloat).cuda())

            def forward(self, x):
                return x + self.cplx

        cplx = torch.nn.DataParallel(Cplx().cuda())
        input = torch.rand(1, 10, dtype=torch.cfloat).cuda()
        result = cplx(input)
        # 2 is the extra real view dimension here
        self.assertEqual(result.size(), torch.Size([1, 10, 2]))
        self.assertEqual(result, torch.view_as_real(input))

    def _test_gather(self, output_device):
        """Gather tensors and scalars from devices 0 and 1 onto ``output_device``.

        ``output_device == -1`` means the result is expected on the CPU.
        Both the gathered values and the gradients flowing back are checked.
        """
        inputs = (
            torch.randn(2, 4, device='cuda:0', requires_grad=True, dtype=torch.double),
            torch.randn(2, 4, device='cuda:1', requires_grad=True, dtype=torch.double),
        )
        result = dp.gather(inputs, output_device)
        self.assertEqual(result.size(), torch.Size([4, 4]))
        self.assertEqual(result[:2], inputs[0])
        self.assertEqual(result[2:], inputs[1])
        if output_device != -1:
            self.assertEqual(result.get_device(), output_device)
        else:
            self.assertFalse(result.is_cuda)
        grad = torch.randn((4, 4), dtype=torch.double)
        if output_device != -1:
            grad = grad.cuda(output_device)
        result.backward(grad)
        self.assertEqual(inputs[0].grad, grad[:2])
        self.assertEqual(inputs[1].grad, grad[2:])
        _assertGradAndGradgradChecks(self, lambda x, y: dp.gather((x, y), output_device), inputs)

        # test scalar inputs, should stack into a vector in this case
        inputs = (
            torch.randn((), device='cuda:0', requires_grad=True, dtype=torch.double),
            torch.randn((), device='cuda:1', requires_grad=True, dtype=torch.double),
        )
        result = dp.gather(inputs, output_device)
        self.assertEqual(result.size(), torch.Size([2]))
        self.assertEqual(result[0], inputs[0])
        self.assertEqual(result[1], inputs[1])
        if output_device != -1:
            self.assertEqual(result.get_device(), output_device)
        else:
            self.assertFalse(result.is_cuda)
        grad = torch.randn(2, dtype=torch.double)
        if output_device != -1:
            grad = grad.cuda(output_device)
        result.backward(grad)
        self.assertEqual(inputs[0].grad, grad[0])
        self.assertEqual(inputs[1].grad, grad[1])
        _assertGradAndGradgradChecks(self, lambda x, y: dp.gather((x, y), output_device), inputs)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_gather_cpu(self):
        """Gather onto the CPU."""
        self._test_gather(-1)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_gather_gpu(self):
        """Gather onto GPU 0."""
        self._test_gather(0)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_gather_different_len_dicts(self):
        """gather raises ValueError when replica dict outputs differ in length."""
        inputs = (
            {'a': torch.randn(1, 2, requires_grad=True, device="cuda:0")},
            {
                'b': torch.randn(1, 2, requires_grad=True, device="cuda:1"),
                'a': torch.randn(1, 2, requires_grad=True, device="cuda:1"),
            }
        )
        with self.assertRaises(ValueError):
            _ = dp.gather(inputs, target_device=0)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_replicate(self):
        """Each replica's parameters are on its own device, and its output matches the original."""
        module = nn.Linear(10, 5).float().cuda()
        input = torch.randn(2, 10, dtype=torch.float, device="cuda")
        expected_output = module(input)
        for devices in [(0, 1), [0, 1]]:
            replicas = dp.replicate(module, devices)
            for i, replica in enumerate(replicas):
                for p in replica.parameters():
                    self.assertEqual(p.get_device(), i)
                replica_input = input.cuda(i)
                self.assertEqual(replica(replica_input), expected_output)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_replicate_buffers(self):
        """replicate places every BatchNorm buffer on the replica's device."""
        net = nn.Module()
        net.bn = nn.BatchNorm2d(10)
        net.cuda()
        for devices in [(0, 1), [0, 1]]:
            replicas = dp.replicate(net, devices)
            for i, replica in enumerate(replicas):
                self.assertEqual(replica.bn.running_mean.get_device(), i, msg='buffer on wrong device')
                self.assertEqual(replica.bn.running_var.get_device(), i, msg='buffer on wrong device')
                self.assertEqual(replica.bn.num_batches_tracked.get_device(), i, msg='buffer on wrong device')

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_zero_grad(self):
        """Calling zero_grad() on a replica inside forward warns that it has no effect."""

        class Net(torch.nn.Module):
            def __init__(self, testcase):
                super().__init__()
                self._testcase = testcase

            def forward(self, x):
                with self._testcase.assertWarnsRegex(
                        UserWarning,
                        r"Calling \.zero_grad\(\) from a module created with nn\.DataParallel\(\) has no effect."):
                    self.zero_grad()
                return x

        module = Net(self).cuda()
        dpm = dp.DataParallel(module)
        dpm(torch.rand(4, 3, 6, 5))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_autocast(self):
        """A forward decorated with autocast still yields float16 outputs under DataParallel."""
        class Model(torch.nn.Linear):
            def __init__(self):
                super().__init__(8, 8)

            @torch.cuda.amp.autocast()
            def forward(self, input):
                return super().forward(input)

        model = dp.DataParallel(Model().cuda().to(dtype=torch.float32))
        input = torch.randn((8, 8), dtype=torch.float32, device="cuda")
        self.assertTrue(model(input).dtype is torch.float16)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_save_replica_module(self):
        """Replicas, detached or not, can be passed to torch.save (gh-37182)."""
        # DataParallel replicas can be saved (gh-37182)
        module = torch.nn.Linear(8, 8).cuda()
        dpm = torch.nn.parallel.replicate(module, devices=[0, 1], detach=False)
        data = io.BytesIO()
        torch.save(dpm, data)
        dpm = torch.nn.parallel.replicate(module, devices=[0, 1], detach=True)
        torch.save(dpm, data)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_strided_grad_layout(self):
        """Grad memory format and values match with and without DataParallel.

        Covers contiguous and channels_last layouts combined with float/half dtypes.
        """
        class ConvNet(nn.Module):
            def __init__(self, layouts, dtype_list):
                super().__init__()
                self.dtypes = dtype_list
                self.conv0 = torch.nn.Conv2d(8, 16, (2, 2)).to(memory_format=layouts[0], dtype=dtype_list[0])
                self.conv1 = torch.nn.Conv2d(16, 32, (2, 2)).to(memory_format=layouts[1], dtype=dtype_list[1])
                self.conv2 = torch.nn.Conv2d(32, 16, (2, 2)).to(memory_format=layouts[2], dtype=dtype_list[2])
                self.conv3 = torch.nn.Conv2d(16, 8, (2, 2)).to(memory_format=layouts[3], dtype=dtype_list[3])

            def forward(self, x):
                x = x.to(self.dtypes[0])
                x = self.conv0(x).to(self.dtypes[1])
                x = self.conv1(x).to(self.dtypes[2])
                x = self.conv2(x).to(self.dtypes[3])
                x = self.conv3(x)
                return x

        layer_formats = ([torch.contiguous_format] * 4,
                         [torch.channels_last] * 2 + [torch.contiguous_format] * 2,
                         [torch.channels_last] * 4,)
        layer_dtypes = ([torch.float] * 4,
                        [torch.float] * 2 + [torch.half] * 2,
                        [torch.half] * 4,)

        ndevs = torch.cuda.device_count()
        input = torch.randn(ndevs * 8, 8, 8, 8, device="cuda:0", dtype=torch.float)
        target = torch.randn(ndevs * 8, 8, 4, 4, device="cuda:0", dtype=torch.float)
        device_ids = list(range(ndevs))

        with torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False):
            for formats, dtype_list in product(layer_formats, layer_dtypes):
                # Use the per-case dtype list here. The bare name ``dtypes`` is the
                # imported test decorator, which would print a function repr.
                model_msg = f"formats = {formats} dtypes = {dtype_list}"
                try:
                    m = ConvNet(formats, dtype_list).cuda(device="cuda:0")
                    m_dp = dp.DataParallel(deepcopy(m), device_ids=device_ids)
                    opt = torch.optim.SGD(m.parameters(), lr=0.1)
                    opt_dp = torch.optim.SGD(m_dp.parameters(), lr=0.1)
                    has_half = any(p.dtype is torch.half for p in m.parameters())
                    tol = 1.e-3 if has_half else 1.e-5
                except BaseException:
                    # Prints case-specific debugging info to narrow down failing case.
                    print("Caught exception during model creation for " + model_msg, flush=True)
                    raise
                # 2 iters:  First iter creates grads, second iter tries zeroed grads.
                for it in range(2):
                    iter_msg = f"iter = {it} " + model_msg
                    named_msg = iter_msg
                    try:
                        F.mse_loss(m(input).float(), target).backward()
                        F.mse_loss(m_dp(input).float(), target).backward()
                        for i, ((layer_name, m_child), m_dp_child) in enumerate(zip(m.named_children(),
                                                                                    m_dp.module.children())):
                            named_msg = layer_name + ".weight " + iter_msg
                            self.assertTrue(m_child.weight.grad.is_contiguous(memory_format=formats[i]), named_msg)
                            self.assertTrue(m_dp_child.weight.grad.is_contiguous(memory_format=formats[i]), named_msg)
                            for (param_name, p), p_dp in zip(m_child.named_parameters(),
                                                             m_dp_child.parameters()):
                                named_msg = layer_name + "." + param_name + " " + iter_msg
                                self.assertEqual(p.grad, p_dp.grad, rtol=tol, atol=tol)
                        opt.step()
                        opt_dp.step()
                        opt.zero_grad()
                        opt_dp.zero_grad()
                    except BaseException:
                        # Makes sure we still get info if an error occurred somewhere other than the asserts.
                        print("Caught exception during iterations at " + named_msg, flush=True)
                        raise

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_parameter_list_dict_replica(self):
        """ParameterList/ParameterDict entries on replicas keep their values and autograd history.

        Each replica entry must require grad and have a ``grad_fn``.
        """
        class MyMod(torch.nn.Module):
            def __init__(self, data, check_fn):
                super().__init__()
                self.data = data
                self.check_fn = check_fn

            def forward(self, inp):
                self.check_fn(self)
                return inp

        p1 = torch.nn.Parameter(torch.rand(10))
        p2 = torch.nn.Parameter(torch.rand(10))
        key0 = 0
        key1 = 1

        # Reads key0/key1 at call time, so rebinding them below switches the
        # check from list indices to dict keys.
        def check_fn(self_):
            self.assertEqual(p1, self_.data[key0])
            self.assertEqual(p2, self_.data[key1])
            self.assertTrue(self_.data[key0].requires_grad)
            self.assertTrue(self_.data[key1].requires_grad)
            self.assertIsNotNone(self_.data[key0].grad_fn)
            self.assertIsNotNone(self_.data[key1].grad_fn)

        module = MyMod(torch.nn.ParameterList([p1, p2]), check_fn).cuda()
        model = dp.DataParallel(module)
        input = torch.randn((8, 8), device="cuda")

        # Runs the check_fn
        model(input)

        key0 = "0"
        key1 = "1"
        module = MyMod(torch.nn.ParameterDict({"0": p1, "1": p2}), check_fn).cuda()
        model = dp.DataParallel(module)
        input = torch.randn((8, 8), device="cuda")

        # Runs the check_fn
        model(input)


class _LinearOnKwarg(nn.Module):
    """Apply ``linear`` to the tensor passed as ``input``.

    The forward parameter is named ``input`` so tests can call the module
    purely with the keyword argument ``input=...``.
    """

    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, input):
        return self.linear(input)


class _LinearOnKwargDict(_LinearOnKwarg):
    """Like ``_LinearOnKwarg``, but ``input`` is a dict and only ``input['data']`` is read."""

    def forward(self, input):
        return self.linear(input['data'])


class TestDataParallelDeviceType(TestCase):
    """Device-generic ``nn.DataParallel`` tests; each one is restricted to CUDA via ``@onlyCUDA``.

    Every test compares the DataParallel output with running the wrapped
    ``nn.Linear`` directly, using a tolerance from ``dtype2prec_DONTUSE``.
    """

    # NOTE: helpers are regular methods, not staticmethods. Non-test members
    # get copied onto the generated per-device classes, and that copy should
    # keep them bound as ordinary methods.
    def _make_linear_and_input(self, device, dtype):
        """Return a 10->5 ``nn.Linear`` and a (20, 10) random input on ``device`` with ``dtype``."""
        linear = nn.Linear(10, 5).to(device, dtype)
        inp = torch.randn(20, 10, device=device, dtype=dtype)
        return linear, inp

    def _assert_dp_output(self, out, expected_out, dtype):
        """Assert ``out`` lives on device 0 and matches ``expected_out`` within ``dtype`` precision."""
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)

    def _check_kwargs_only_dict_input(self, device, dtype, unused):
        """Shared body of the ``kwargs_only_empty_*`` tests.

        Calls the DataParallel module with the single keyword argument
        ``input={'data': tensor, 'unused': unused}``, where ``unused`` is an
        empty container that the wrapped module never reads.
        """
        linear, inp = self._make_linear_and_input(device, dtype)
        expected_out = linear(inp)
        net = nn.DataParallel(_LinearOnKwargDict(linear))
        out = net(input={'data': inp, 'unused': unused})
        self._assert_dp_output(out, expected_out, dtype)

    @onlyCUDA
    @skipMeta
    @dtypes(torch.float, torch.double, torch.half)
    def test_data_parallel_module(self, device, dtype):
        # Plain positional call on a bare nn.Linear.
        linear, inp = self._make_linear_and_input(device, dtype)
        expected_out = linear(inp)
        out = nn.DataParallel(linear)(inp)
        self._assert_dp_output(out, expected_out, dtype)

    @onlyCUDA
    @skipMeta
    @dtypes(torch.float, torch.double, torch.half)
    def test_data_parallel_module_kwargs_only(self, device, dtype):
        # The module is called with keyword arguments only (no positionals).
        linear, inp = self._make_linear_and_input(device, dtype)
        expected_out = linear(inp)
        out = nn.DataParallel(_LinearOnKwarg(linear))(input=inp)
        self._assert_dp_output(out, expected_out, dtype)

    @onlyCUDA
    @skipMeta
    @dtypes(torch.float, torch.double, torch.half)
    def test_data_parallel_module_kwargs_only_empty_list(self, device, dtype):
        self._check_kwargs_only_dict_input(device, dtype, [])

    @onlyCUDA
    @skipMeta
    @dtypes(torch.float, torch.double, torch.half)
    def test_data_parallel_module_kwargs_only_empty_dict(self, device, dtype):
        self._check_kwargs_only_dict_input(device, dtype, {})

    @onlyCUDA
    @skipMeta
    @dtypes(torch.float, torch.double, torch.half)
    def test_data_parallel_module_kwargs_only_empty_tuple(self, device, dtype):
        self._check_kwargs_only_dict_input(device, dtype, ())


# Expand the device-generic TestDataParallelDeviceType (whose tests take
# ``device``/``dtype`` arguments) into concrete per-device-type test classes
# registered in this module's globals, so the test runner can discover them.
instantiate_device_type_tests(TestDataParallelDeviceType, globals())

if __name__ == '__main__':
    # NOTE(review): presumably turns on TestCase's default-dtype consistency
    # check for direct runs of this file; confirm against common_utils.TestCase.
    TestCase._default_dtype_check_enabled = True
    run_tests()
